A2C (Advantage Actor-Critic) — Low-Level PyTorch Implementation (CartPole-v1)#
A2C is an on-policy actor-critic algorithm:
the actor learns a policy \(\pi_\theta(a\mid s)\) (how to act)
the critic learns a value function \(V_\phi(s)\) (how good a state is)
the actor is trained with advantages (“better than expected” signals)
This notebook builds the math carefully, then implements A2C with minimal PyTorch (no RL libraries, no high-level training abstractions), using a vectorized Gymnasium environment for synchronous rollouts.
Learning goals#
By the end you should be able to:
derive the A2C update from the policy gradient theorem
explain why the baseline (critic) reduces variance
implement GAE(\(\gamma,\lambda\)) and n-step bootstrapped returns
train an A2C agent on CartPole-v1 and visualize learning with Plotly
map the concepts to Stable-Baselines3 A2C hyperparameters
Notebook roadmap#
A2C intuition + what “advantage” means
Mathematical formulation (LaTeX)
Low-level PyTorch implementation (actor + critic)
Training on CartPole with vectorized rollouts
Plotly diagnostics (returns, losses, policy/value slices)
Stable-Baselines3 A2C reference + hyperparameters
import math
import time
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
import plotly.io as pio
from plotly.subplots import make_subplots
try:
import gymnasium as gym
GYMNASIUM_AVAILABLE = True
except Exception as e:
GYMNASIUM_AVAILABLE = False
_GYM_IMPORT_ERROR = e
try:
import torch
import torch.nn as nn
import torch.nn.functional as F
TORCH_AVAILABLE = True
except Exception as e:
TORCH_AVAILABLE = False
_TORCH_IMPORT_ERROR = e
pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)
assert GYMNASIUM_AVAILABLE, f"gymnasium import failed: {_GYM_IMPORT_ERROR}"
assert TORCH_AVAILABLE, f"torch import failed: {_TORCH_IMPORT_ERROR}"
print('gymnasium', gym.__version__)
print('torch', torch.__version__)
gymnasium 1.1.1
torch 2.7.0+cu126
# --- Run configuration ---
# Keep FAST_RUN=True for a quick demo.
# For a more reliable "solve", set FAST_RUN=False.
FAST_RUN = True
ENV_ID = "CartPole-v1" # discrete actions, small continuous state
SEED = 42
# A2C is usually run with multiple envs in parallel.
N_ENVS = 8 if FAST_RUN else 16
# Rollout horizon per env (A2C commonly uses small n_steps).
N_STEPS = 5
# Total interaction budget
TOTAL_TIMESTEPS = 30_000 if FAST_RUN else 200_000
# Core RL hyperparameters
GAMMA = 0.99
GAE_LAMBDA = 1.0 # 1.0 => classic advantage w/ n-step bootstrapping
# Loss weights
ENT_COEF = 0.01
VF_COEF = 0.5
# Optimization
LR = 7e-4
MAX_GRAD_NORM = 0.5
RMSPROP_EPS = 1e-5
# Optional: normalize advantage each update
NORMALIZE_ADVANTAGE = True
# Network
HIDDEN_SIZES = (128, 128)
# Logging
LOG_EVERY_UPDATES = 50
RETURN_SMOOTHING_WINDOW = 50
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device', DEVICE)
device cpu
1) A2C intuition: actor + critic + advantage#
Actor#
The actor is a stochastic policy \(\pi_\theta(a\mid s)\).
It outputs a distribution over actions.
We sample actions from that distribution to explore.
Critic#
The critic is a value function \(V_\phi(s)\).
It predicts the expected discounted return from state \(s\).
It is trained via regression to match a bootstrapped return target.
Advantage#
The advantage measures how much better an action did compared to what the critic expected:

\[A(s_t, a_t) = Q(s_t, a_t) - V_\phi(s_t)\]

If \(A(s_t,a_t)\) is positive, the action was better than expected, and the actor should increase its probability.
Why “A2C”?#
A2C is the synchronous version of A3C:
A3C: many workers update parameters asynchronously.
A2C: many workers collect experience in parallel, then we do a single synchronized update.
In practice, A2C typically uses a vectorized environment and batches data as a single update batch of \(n_{\text{steps}} \times n_{\text{envs}}\) transitions.
2) Mathematical formulation (policy gradient + baseline)#
We model the environment as an MDP \((\mathcal{S}, \mathcal{A}, P, r, \gamma)\).
Return#
The discounted return from time \(t\) is:

\[G_t = \sum_{k=0}^{\infty} \gamma^k \, r_{t+k}\]
Objective#
We want to maximize expected return:

\[J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\!\left[G_0\right]\]
Policy gradient theorem#
A standard form is:

\[\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\!\left[\sum_{t} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, G_t\right]\]
Baseline (variance reduction)#
We can subtract a baseline \(b(s_t)\) without changing the expectation:

\[\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\!\left[\sum_{t} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\,\big(G_t - b(s_t)\big)\right]\]

Choosing \(b(s_t)=V_\phi(s_t)\) yields the advantage form:

\[\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\!\left[\sum_{t} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, A(s_t, a_t)\right]\]
Bootstrapped n-step return#
With a rollout horizon \(T\) (a.k.a. n_steps), we use a bootstrapped target:

\[\hat{G}_t = \sum_{k=0}^{T-t-1} \gamma^k r_{t+k} + \gamma^{T-t}\, V_\phi(s_T)\]
Generalized Advantage Estimation (GAE)#
GAE defines the TD residual:

\[\delta_t = r_t + \gamma V_\phi(s_{t+1}) - V_\phi(s_t)\]

and computes advantages with an exponentially-weighted sum:

\[\hat{A}_t = \sum_{l=0}^{\infty} (\gamma\lambda)^l\, \delta_{t+l}\]
\(\lambda=1\) recovers the classic (higher-variance) advantage.
smaller \(\lambda\) reduces variance but increases bias.
Loss functions (minimization form)#
Actor loss (to maximize expected advantage):

\[\mathcal{L}_{\text{actor}} = -\,\mathbb{E}\big[\log \pi_\theta(a_t \mid s_t)\, \hat{A}_t\big]\]

Critic loss (value regression):

\[\mathcal{L}_{\text{critic}} = \tfrac{1}{2}\,\mathbb{E}\big[\big(V_\phi(s_t) - \hat{G}_t\big)^2\big]\]

Entropy bonus (encourage exploration):

\[\mathcal{H}[\pi_\theta] = -\,\mathbb{E}\Big[\sum_a \pi_\theta(a \mid s_t) \log \pi_\theta(a \mid s_t)\Big]\]

Total loss:

\[\mathcal{L} = \mathcal{L}_{\text{actor}} + c_v\, \mathcal{L}_{\text{critic}} - c_e\, \mathcal{H}[\pi_\theta]\]
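As a quick numeric sketch of how these loss terms combine (the logits, values, advantages, and targets below are made-up dummy data, not rollout data; the coefficients \(c_v=0.5\), \(c_e=0.01\) match the notebook's `VF_COEF` and `ENT_COEF`):

```python
import torch
import torch.nn.functional as F

# Dummy batch: 4 states, 2 actions (all numbers invented for illustration).
logits = torch.tensor([[0.2, -0.1], [0.5, 0.5], [-1.0, 1.0], [0.0, 0.3]])
values_pred = torch.tensor([1.0, 0.5, 2.0, 1.5])
actions = torch.tensor([0, 1, 1, 0])
advantages = torch.tensor([0.8, -0.2, 1.1, 0.3])  # A-hat
returns = torch.tensor([1.5, 0.4, 2.9, 1.6])      # G-hat targets

log_probs = F.log_softmax(logits, dim=-1)
logp = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
entropy = -(log_probs.exp() * log_probs).sum(dim=-1).mean()

actor_loss = -(logp * advantages).mean()               # maximize E[log pi * A]
critic_loss = 0.5 * F.mse_loss(values_pred, returns)   # value regression
loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy # c_v = 0.5, c_e = 0.01

print(float(loss))
```

This is exactly the combination the training loop below minimizes, just on a hand-built batch.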
def make_vec_env(env_id: str, n_envs: int, seed: int) -> gym.vector.SyncVectorEnv:
env_fns = [lambda: gym.make(env_id) for _ in range(n_envs)]
env = gym.vector.SyncVectorEnv(env_fns, autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
env.reset(seed=[seed + i for i in range(n_envs)])
return env
env = make_vec_env(ENV_ID, N_ENVS, SEED)
obs_space = env.single_observation_space
act_space = env.single_action_space
assert isinstance(act_space, gym.spaces.Discrete), "This notebook's implementation uses discrete actions (Categorical)."
OBS_DIM = int(np.prod(obs_space.shape))
N_ACTIONS = int(act_space.n)
print('obs_space', obs_space)
print('act_space', act_space)
print('OBS_DIM', OBS_DIM, 'N_ACTIONS', N_ACTIONS)
obs_space Box([-4.8 -inf -0.4189 -inf], [4.8 inf 0.4189 inf], (4,), float32)
act_space Discrete(2)
OBS_DIM 4 N_ACTIONS 2
3) Actor-Critic network (low-level PyTorch)#
We use a shared MLP trunk, then two heads:
actor head outputs logits for a categorical distribution
critic head outputs a scalar value \(V(s)\)
This is not the only architecture (you can also use separate networks), but it’s a common and effective baseline.
class ActorCritic(nn.Module):
def __init__(self, obs_dim: int, n_actions: int, hidden_sizes: tuple[int, int] = (128, 128)):
super().__init__()
h1, h2 = hidden_sizes
self.fc1 = nn.Linear(obs_dim, h1)
self.fc2 = nn.Linear(h1, h2)
self.actor = nn.Linear(h2, n_actions)
self.critic = nn.Linear(h2, 1)
def forward(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# obs: (B, obs_dim)
x = torch.tanh(self.fc1(obs))
x = torch.tanh(self.fc2(x))
logits = self.actor(x) # (B, n_actions)
values = self.critic(x).squeeze(-1) # (B,)
return logits, values
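The text above mentions that separate actor and critic networks are also an option. A minimal sketch of that alternative (the class name `SeparateActorCritic` is hypothetical, not from the notebook; it keeps the same `(logits, values)` interface so it could drop into the training loop):

```python
import torch
import torch.nn as nn

class SeparateActorCritic(nn.Module):
    """Hypothetical variant: no shared trunk, same (logits, values) interface."""
    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 128):
        super().__init__()
        self.pi = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, n_actions),
        )
        self.v = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, 1),
        )

    def forward(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        return self.pi(obs), self.v(obs).squeeze(-1)

m = SeparateActorCritic(4, 2)
logits, values = m(torch.zeros(3, 4))
print(logits.shape, values.shape)  # torch.Size([3, 2]) torch.Size([3])
```

The trade-off: separate networks avoid interference between the two losses on shared features, at the cost of roughly doubling parameters.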
def sample_actions_and_logp(logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Low-level categorical sampling without torch.distributions.
Returns:
actions: (B,) int64
logp: (B,) log-prob of sampled action
entropy: (B,) categorical entropy
"""
log_probs = F.log_softmax(logits, dim=-1) # (B, A)
probs = log_probs.exp()
actions = torch.multinomial(probs, num_samples=1).squeeze(-1) # (B,)
logp = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
entropy = -(probs * log_probs).sum(dim=-1)
return actions, logp, entropy
@torch.no_grad()
def policy_action_probs(logits: torch.Tensor) -> torch.Tensor:
return F.softmax(logits, dim=-1)
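One way to sanity-check the low-level sampling math is to compare its log-probs and entropy against `torch.distributions.Categorical` (the manual computations are restated inline so this check runs standalone):

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.randn(5, 3)

# Manual versions, as in sample_actions_and_logp above.
log_probs = F.log_softmax(logits, dim=-1)
probs = log_probs.exp()
entropy_manual = -(probs * log_probs).sum(dim=-1)

dist = Categorical(logits=logits)
actions = dist.sample()
logp_manual = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)

assert torch.allclose(entropy_manual, dist.entropy(), atol=1e-6)
assert torch.allclose(logp_manual, dist.log_prob(actions), atol=1e-6)
print("manual sampling math matches torch.distributions")
```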
4) GAE implementation#
We compute advantages backwards in time:

\[\delta_t = r_t + \gamma\, (1 - d_t)\, V(s_{t+1}) - V(s_t)\]
\[\hat{A}_t = \delta_t + \gamma \lambda\, (1 - d_t)\, \hat{A}_{t+1}\]

where \(d_t\in\{0,1\}\) is the done flag.
def compute_gae(
rewards: torch.Tensor, # (T, N)
dones: torch.Tensor, # (T, N) float32 {0,1}
values: torch.Tensor, # (T, N)
last_values: torch.Tensor, # (N,)
gamma: float,
gae_lambda: float,
) -> tuple[torch.Tensor, torch.Tensor]:
T, N = rewards.shape
advantages = torch.zeros((T, N), device=rewards.device, dtype=torch.float32)
last_adv = torch.zeros((N,), device=rewards.device, dtype=torch.float32)
for t in reversed(range(T)):
mask = 1.0 - dones[t]
next_values = last_values if t == T - 1 else values[t + 1]
delta = rewards[t] + gamma * mask * next_values - values[t]
last_adv = delta + gamma * gae_lambda * mask * last_adv
advantages[t] = last_adv
returns = advantages + values
return advantages, returns
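With \(\lambda=1\) and no episode boundaries, the GAE recursion should reduce to the discounted bootstrapped n-step return minus the value estimate. A self-contained check on a tiny hand-built rollout (the local `gae` function below mirrors `compute_gae` above so the snippet runs on its own):

```python
import torch

def gae(rewards, dones, values, last_values, gamma, lam):
    # Backwards recursion, structurally identical to compute_gae above.
    T, N = rewards.shape
    adv = torch.zeros_like(rewards)
    last_adv = torch.zeros(N)
    for t in reversed(range(T)):
        mask = 1.0 - dones[t]
        next_v = last_values if t == T - 1 else values[t + 1]
        delta = rewards[t] + gamma * mask * next_v - values[t]
        last_adv = delta + gamma * lam * mask * last_adv
        adv[t] = last_adv
    return adv, adv + values

gamma = 0.9
rewards = torch.tensor([[1.0], [1.0], [1.0]])   # (T=3, N=1)
dones = torch.zeros(3, 1)                       # no episode ends
values = torch.tensor([[0.5], [0.4], [0.3]])
last_values = torch.tensor([0.2])               # V(s_T)

adv, ret = gae(rewards, dones, values, last_values, gamma, lam=1.0)

# n-step target for t=0: r0 + g*r1 + g^2*r2 + g^3*V(s_T)
expected = 1.0 + gamma * 1.0 + gamma**2 * 1.0 + gamma**3 * 0.2
assert torch.isclose(ret[0, 0], torch.tensor(expected))
print(float(ret[0, 0]))  # 2.8558 (approx)
```

Dropping `lam` below 1.0 shrinks the contribution of later TD residuals, which is exactly the bias/variance knob described in the math section.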
5) Training loop (A2C)#
Key design choices in this minimal implementation:
Vectorized envs (n_envs) to match A2C’s synchronous batching.
Rollout buffer of shape (n_steps, n_envs, ...).
Compute GAE + bootstrapped returns.
Single gradient update per rollout (no replay buffer, no off-policy corrections).
We also record:
episodic return (score) whenever any env finishes an episode
actor loss, critic loss, entropy, explained variance (optional diagnostic)
def explained_variance(y_true: np.ndarray, y_pred: np.ndarray) -> float:
var_y = np.var(y_true)
if var_y < 1e-12:
return float('nan')
return 1.0 - float(np.var(y_true - y_pred) / var_y)
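For intuition about this diagnostic, a tiny made-up example (with `explained_variance` restated inline so it runs standalone): a perfect critic scores 1, a critic that only predicts the mean scores 0, and worse-than-mean predictions go negative.

```python
import numpy as np

def explained_variance(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    var_y = np.var(y_true)
    if var_y < 1e-12:
        return float('nan')
    return 1.0 - float(np.var(y_true - y_pred) / var_y)

y = np.array([1.0, 2.0, 3.0, 4.0])
print(explained_variance(y, y))                          # 1.0 (perfect predictions)
print(explained_variance(y, np.full_like(y, y.mean())))  # 0.0 (mean predictor)
print(explained_variance(y, -y))                         # -3.0 (anti-correlated)
```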
def train_a2c(
env_id: str,
seed: int,
device: torch.device,
n_envs: int,
n_steps: int,
total_timesteps: int,
gamma: float,
gae_lambda: float,
ent_coef: float,
vf_coef: float,
lr: float,
max_grad_norm: float,
rmsprop_eps: float,
hidden_sizes: tuple[int, int],
normalize_advantage: bool,
log_every_updates: int = 50,
):
torch.manual_seed(seed)
np.random.seed(seed)
env = make_vec_env(env_id, n_envs, seed)
obs_space = env.single_observation_space
act_space = env.single_action_space
obs_dim = int(np.prod(obs_space.shape))
n_actions = int(act_space.n)
model = ActorCritic(obs_dim, n_actions, hidden_sizes=hidden_sizes).to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, eps=rmsprop_eps)
# Rollout buffers
obs_buf = torch.zeros((n_steps, n_envs, obs_dim), device=device, dtype=torch.float32)
act_buf = torch.zeros((n_steps, n_envs), device=device, dtype=torch.int64)
rew_buf = torch.zeros((n_steps, n_envs), device=device, dtype=torch.float32)
done_buf = torch.zeros((n_steps, n_envs), device=device, dtype=torch.float32)
val_buf = torch.zeros((n_steps, n_envs), device=device, dtype=torch.float32)
obs, _ = env.reset(seed=[seed + i for i in range(n_envs)])
# Episode tracking across vector envs
ep_returns_running = np.zeros((n_envs,), dtype=np.float32)
ep_lengths_running = np.zeros((n_envs,), dtype=np.int32)
ep_returns: list[float] = []
ep_lengths: list[int] = []
updates = total_timesteps // (n_envs * n_steps)
history_updates: list[dict] = []
last_adv_flat = None
t0 = time.time()
global_step = 0
model.train()
for update in range(1, updates + 1):
# --- Collect rollout ---
for t in range(n_steps):
obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
obs_buf[t] = obs_t
with torch.no_grad():
logits, values = model(obs_t)
actions, _, _ = sample_actions_and_logp(logits)
act_buf[t] = actions
val_buf[t] = values
next_obs, rewards, terminated, truncated, _ = env.step(actions.cpu().numpy())
dones = np.logical_or(terminated, truncated)
rew_buf[t] = torch.as_tensor(rewards, dtype=torch.float32, device=device)
done_buf[t] = torch.as_tensor(dones, dtype=torch.float32, device=device)
# Episode bookkeeping
ep_returns_running += rewards
ep_lengths_running += 1
for i in np.where(dones)[0]:
ep_returns.append(float(ep_returns_running[i]))
ep_lengths.append(int(ep_lengths_running[i]))
ep_returns_running[i] = 0.0
ep_lengths_running[i] = 0
obs = next_obs
global_step += n_envs
# Bootstrap value from last observation
with torch.no_grad():
obs_last = torch.as_tensor(obs, dtype=torch.float32, device=device)
_, last_values = model(obs_last) # (N,)
advantages, returns = compute_gae(
rewards=rew_buf,
dones=done_buf,
values=val_buf,
last_values=last_values,
gamma=gamma,
gae_lambda=gae_lambda,
)
# Flatten (T, N, ...) -> (T*N, ...)
b_obs = obs_buf.reshape(-1, obs_dim)
b_act = act_buf.reshape(-1)
b_adv = advantages.reshape(-1)
b_ret = returns.reshape(-1)
if normalize_advantage:
b_adv = (b_adv - b_adv.mean()) / (b_adv.std() + 1e-8)
# --- Compute losses ---
logits, values_pred = model(b_obs)
log_probs = F.log_softmax(logits, dim=-1)
probs = log_probs.exp()
b_logp = log_probs.gather(1, b_act.unsqueeze(1)).squeeze(1)
entropy = -(probs * log_probs).sum(dim=-1).mean()
actor_loss = -(b_logp * b_adv.detach()).mean()
critic_loss = 0.5 * F.mse_loss(values_pred, b_ret.detach())
loss = actor_loss + vf_coef * critic_loss - ent_coef * entropy
optimizer.zero_grad(set_to_none=True)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
optimizer.step()
last_adv_flat = b_adv.detach().cpu().numpy()
# Diagnostics
y_true = b_ret.detach().cpu().numpy()
y_pred = values_pred.detach().cpu().numpy()
mean_ep_return = float(np.mean(ep_returns[-RETURN_SMOOTHING_WINDOW:])) if len(ep_returns) else float('nan')
history_updates.append(
dict(
update=update,
timesteps=global_step,
actor_loss=float(actor_loss.detach().cpu().item()),
critic_loss=float(critic_loss.detach().cpu().item()),
entropy=float(entropy.detach().cpu().item()),
explained_variance=explained_variance(y_true, y_pred),
episodes=len(ep_returns),
mean_return_window=mean_ep_return,
)
)
if update % log_every_updates == 0 or update == 1 or update == updates:
elapsed = time.time() - t0
print(
f"update {update:>4d}/{updates} | steps {global_step:>7d} | episodes {len(ep_returns):>5d} | "
f"mean_return@{RETURN_SMOOTHING_WINDOW} {mean_ep_return:>7.1f} | "
f"loss {float(loss.detach().cpu()):>8.4f} | {elapsed:>6.1f}s"
)
env.close()
hist_df = pd.DataFrame(history_updates)
return model, hist_df, np.array(ep_returns, dtype=np.float32), np.array(ep_lengths, dtype=np.int32), last_adv_flat
model, hist_df, ep_returns, ep_lengths, last_adv_flat = train_a2c(
env_id=ENV_ID,
seed=SEED,
device=DEVICE,
n_envs=N_ENVS,
n_steps=N_STEPS,
total_timesteps=TOTAL_TIMESTEPS,
gamma=GAMMA,
gae_lambda=GAE_LAMBDA,
ent_coef=ENT_COEF,
vf_coef=VF_COEF,
lr=LR,
max_grad_norm=MAX_GRAD_NORM,
rmsprop_eps=RMSPROP_EPS,
hidden_sizes=HIDDEN_SIZES,
normalize_advantage=NORMALIZE_ADVANTAGE,
log_every_updates=LOG_EVERY_UPDATES,
)
hist_df.tail()
update 1/750 | steps 40 | episodes 0 | mean_return@50 nan | loss 2.6692 | 0.0s
update 50/750 | steps 2000 | episodes 82 | mean_return@50 24.9 | loss 1.6804 | 0.2s
update 100/750 | steps 4000 | episodes 146 | mean_return@50 31.8 | loss 7.0746 | 0.5s
update 150/750 | steps 6000 | episodes 211 | mean_return@50 33.0 | loss 12.9465 | 0.8s
update 200/750 | steps 8000 | episodes 266 | mean_return@50 37.3 | loss 27.3818 | 1.0s
update 250/750 | steps 10000 | episodes 309 | mean_return@50 42.7 | loss 7.7835 | 1.3s
update 300/750 | steps 12000 | episodes 343 | mean_return@50 54.5 | loss 1.4184 | 1.5s
update 350/750 | steps 14000 | episodes 371 | mean_return@50 65.0 | loss 19.3684 | 1.8s
update 400/750 | steps 16000 | episodes 396 | mean_return@50 71.7 | loss 1.0661 | 2.1s
update 450/750 | steps 18000 | episodes 431 | mean_return@50 67.6 | loss 13.2570 | 2.4s
update 500/750 | steps 20000 | episodes 457 | mean_return@50 65.4 | loss 1.0435 | 2.7s
update 550/750 | steps 22000 | episodes 487 | mean_return@50 71.2 | loss 1.0582 | 2.9s
update 600/750 | steps 24000 | episodes 514 | mean_return@50 67.5 | loss 0.9999 | 3.0s
update 650/750 | steps 26000 | episodes 551 | mean_return@50 52.4 | loss 1.1563 | 3.1s
update 700/750 | steps 28000 | episodes 574 | mean_return@50 63.4 | loss 2.1548 | 3.3s
update 750/750 | steps 30000 | episodes 600 | mean_return@50 79.5 | loss 1.1326 | 3.4s
|  | update | timesteps | actor_loss | critic_loss | entropy | explained_variance | episodes | mean_return_window |
|---|---|---|---|---|---|---|---|---|
| 745 | 746 | 29840 | -0.141999 | 2.279103 | 0.612243 | -0.273748 | 600 | 79.48 |
| 746 | 747 | 29880 | -0.164299 | 2.443412 | 0.641806 | -0.466463 | 600 | 79.48 |
| 747 | 748 | 29920 | -0.135936 | 3.744340 | 0.590112 | -1.532006 | 600 | 79.48 |
| 748 | 749 | 29960 | -0.104616 | 3.020100 | 0.607867 | -0.825297 | 600 | 79.48 |
| 749 | 750 | 30000 | -0.107253 | 2.492551 | 0.637384 | -0.567428 | 600 | 79.48 |
6) Plot: score (return) per episode#
CartPole gives reward \(+1\) per time step, so episode return = episode length (up to 500).
episodes = np.arange(1, len(ep_returns) + 1)
roll_mean = pd.Series(ep_returns).rolling(RETURN_SMOOTHING_WINDOW).mean().to_numpy()
fig = go.Figure()
fig.add_trace(go.Scatter(x=episodes, y=ep_returns, mode='lines', name='return', line=dict(width=1)))
fig.add_trace(go.Scatter(x=episodes, y=roll_mean, mode='lines', name=f'mean@{RETURN_SMOOTHING_WINDOW}', line=dict(width=3)))
fig.update_layout(
title='A2C on CartPole-v1 — score (return) per episode',
xaxis_title='episode',
yaxis_title='return',
)
fig.show()
7) Plot: training diagnostics (losses, entropy, explained variance)#
Actor loss becomes more negative when advantages are consistently positive for sampled actions.
Critic loss should generally decrease as the value function fits the returns.
Entropy typically decreases as the policy becomes more confident.
Explained variance (rough critic diagnostic) near 1 is good; near 0 means the critic explains little.
fig = make_subplots(
rows=2,
cols=2,
subplot_titles=("Actor loss", "Critic loss", "Entropy", "Explained variance"),
)
fig.add_trace(go.Scatter(x=hist_df['timesteps'], y=hist_df['actor_loss'], name='actor_loss'), row=1, col=1)
fig.add_trace(go.Scatter(x=hist_df['timesteps'], y=hist_df['critic_loss'], name='critic_loss'), row=1, col=2)
fig.add_trace(go.Scatter(x=hist_df['timesteps'], y=hist_df['entropy'], name='entropy'), row=2, col=1)
fig.add_trace(go.Scatter(x=hist_df['timesteps'], y=hist_df['explained_variance'], name='explained_variance'), row=2, col=2)
fig.update_layout(height=700, title='A2C training diagnostics', showlegend=False)
fig.update_xaxes(title_text='timesteps')
fig.show()
8) Plot: advantage distribution (last update)#
A2C pushes up the log-probability of actions with positive advantage and pushes down those with negative advantage.
fig = px.histogram(
x=last_adv_flat,
nbins=60,
title='Advantage histogram (last update)',
)
fig.update_layout(xaxis_title='advantage', yaxis_title='count')
fig.show()
9) Visualize the learned policy + value function (2D slice)#
CartPole states are 4D: \((x, \dot{x}, \theta, \dot{\theta})\), i.e. cart position, cart velocity, pole angle, and pole angular velocity.
To visualize something, we take a 2D slice over pole angle \(\theta\) and pole angular velocity \(\dot{\theta}\), while fixing \(x=0\) and \(\dot{x}=0\).
Left plot: \(\pi(a=1\mid s)\) (probability of pushing right)
Right plot: \(V(s)\) (critic estimate)
@torch.no_grad()
def policy_value_slice(model: nn.Module, device: torch.device, grid_n: int = 70):
model.eval()
angles = np.linspace(-0.21, 0.21, grid_n) # roughly CartPole angle limits
ang_vels = np.linspace(-3.0, 3.0, grid_n)
theta, theta_dot = np.meshgrid(angles, ang_vels)
states = np.zeros((grid_n * grid_n, 4), dtype=np.float32)
states[:, 2] = theta.ravel()
states[:, 3] = theta_dot.ravel()
obs_t = torch.as_tensor(states, dtype=torch.float32, device=device)
logits, values = model(obs_t)
probs = policy_action_probs(logits)
p_right = probs[:, 1].reshape(grid_n, grid_n).cpu().numpy()
v = values.reshape(grid_n, grid_n).cpu().numpy()
return angles, ang_vels, p_right, v
angles, ang_vels, p_right, v = policy_value_slice(model, DEVICE)
fig = make_subplots(
rows=1,
cols=2,
subplot_titles=("Policy: P(push right)", "Critic: V(s)"),
)
fig.add_trace(
go.Heatmap(
x=angles,
y=ang_vels,
z=p_right,
colorscale='RdBu',
zmin=0.0,
zmax=1.0,
colorbar=dict(title='P(right)'),
),
row=1,
col=1,
)
fig.add_trace(
go.Heatmap(
x=angles,
y=ang_vels,
z=v,
colorscale='Viridis',
colorbar=dict(title='V(s)'),
),
row=1,
col=2,
)
fig.update_layout(
height=420,
title='Learned policy/value on a 2D state slice (x=0, xdot=0)',
)
fig.update_xaxes(title_text='pole angle θ', row=1, col=1)
fig.update_yaxes(title_text='pole angular velocity θdot', row=1, col=1)
fig.update_xaxes(title_text='pole angle θ', row=1, col=2)
fig.update_yaxes(title_text='pole angular velocity θdot', row=1, col=2)
fig.show()
10) Quick evaluation (deterministic actions)#
We evaluate by taking the greedy action \(\arg\max_a \pi(a\mid s)\).
@torch.no_grad()
def evaluate(model: nn.Module, env_id: str, n_episodes: int = 10, seed: int = 0):
env = gym.make(env_id)
returns = []
for ep in range(n_episodes):
obs, _ = env.reset(seed=seed + ep)
done = False
ret = 0.0
while not done:
obs_t = torch.as_tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
logits, _ = model(obs_t)
action = int(torch.argmax(logits, dim=-1).item())
obs, reward, terminated, truncated, _ = env.step(action)
done = bool(terminated or truncated)
ret += float(reward)
returns.append(ret)
env.close()
return np.array(returns, dtype=np.float32)
model.eval()
eval_returns = evaluate(model, ENV_ID, n_episodes=10, seed=SEED + 1000)
print('eval returns:', eval_returns)
print('mean ± std:', float(eval_returns.mean()), '±', float(eval_returns.std()))
eval returns: [474. 500. 319. 245. 374. 419. 211. 354. 247. 292.]
mean ± std: 343.5 ± 93.85440826416016
11) Pitfalls + diagnostics#
On-policy constraint: A2C uses data from the current policy. If you reuse old experience without correction, it becomes biased.
Done handling: You must stop bootstrapping across episode boundaries. Here we treat terminated OR truncated as terminal for simplicity.
Entropy coefficient: Too high keeps the policy random; too low can collapse exploration early.
Critic quality: If the critic lags far behind the policy (or the value loss dominates the shared trunk), advantage estimates become noisy or biased and learning can destabilize.
Parallel envs matter: With too few envs you get higher-variance updates.
Good quick checks:
returns increase over time
entropy decreases gradually (not instantly)
critic loss decreases and explained variance improves
12) Exercises#
Change \(\lambda\) in GAE (e.g. 0.9) and compare learning curves.
Swap RMSprop for Adam and compare stability.
Implement continuous actions by outputting a Gaussian policy (mean + log-std) and testing on Pendulum-v1.
Add a learning-rate schedule.
Add observation normalization and compare speed.
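For the observation-normalization exercise, one common approach is a running mean/std normalizer updated online from rollout batches. A sketch under that assumption (the class name `RunningObsNorm` is hypothetical, not from the notebook; SB3's VecNormalize wrapper does something similar):

```python
import numpy as np

class RunningObsNorm:
    """Hypothetical sketch: running mean/variance via parallel (Chan) updates."""
    def __init__(self, shape, eps: float = 1e-8):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = eps  # tiny prior count avoids division by zero

    def update(self, batch: np.ndarray) -> None:
        # Merge batch statistics into the running statistics.
        b_mean = batch.mean(axis=0)
        b_var = batch.var(axis=0)
        b_count = batch.shape[0]
        delta = b_mean - self.mean
        tot = self.count + b_count
        self.mean = self.mean + delta * b_count / tot
        m_a = self.var * self.count
        m_b = b_var * b_count
        self.var = (m_a + m_b + delta**2 * self.count * b_count / tot) / tot
        self.count = tot

    def __call__(self, obs: np.ndarray) -> np.ndarray:
        return (obs - self.mean) / np.sqrt(self.var + 1e-8)

norm = RunningObsNorm(shape=(4,))
batch = np.random.default_rng(0).normal(3.0, 2.0, size=(1024, 4))
norm.update(batch)
normed = norm(batch)
print(normed.mean(axis=0).round(3), normed.std(axis=0).round(3))
```

In the training loop you would call `norm.update(obs)` on each rollout batch and feed `norm(obs)` to the network instead of raw observations.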
13) Stable-Baselines3 A2C reference implementation (web research)#
Stable-Baselines3 (SB3) includes an A2C implementation.
Docs page: https://stable-baselines3.readthedocs.io/en/master/modules/a2c.html
Minimal usage#
from stable_baselines3 import A2C
import gymnasium as gym
env = gym.make("CartPole-v1")
model = A2C(
policy="MlpPolicy",
env=env,
learning_rate=7e-4,
n_steps=5,
gamma=0.99,
gae_lambda=1.0,
ent_coef=0.0,
vf_coef=0.5,
max_grad_norm=0.5,
rms_prop_eps=1e-5,
use_rms_prop=True,
normalize_advantage=False,
)
model.learn(total_timesteps=200_000)
SB3 A2C hyperparameters (signature + meaning)#
From the SB3 docs, the constructor signature is:
A2C(policy, env, learning_rate=0.0007, n_steps=5, gamma=0.99, gae_lambda=1.0,
ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5, rms_prop_eps=1e-05,
use_rms_prop=True, use_sde=False, sde_sample_freq=-1,
rollout_buffer_class=None, rollout_buffer_kwargs=None,
normalize_advantage=False, stats_window_size=100, tensorboard_log=None,
policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)
Parameter meanings (SB3 docs):
policy: policy class (e.g. MlpPolicy, CnnPolicy)
env: environment (Gym env, VecEnv, or registered env id string)
learning_rate: float or schedule
n_steps: rollout length per env (batch size = n_steps * n_envs)
gamma: discount factor
gae_lambda: bias/variance trade-off for GAE; 1.0 equals classic advantage
ent_coef: entropy coefficient
vf_coef: value loss coefficient
max_grad_norm: gradient clipping threshold
rms_prop_eps: RMSprop epsilon
use_rms_prop: use RMSprop (default) vs Adam
use_sde: generalized State Dependent Exploration (gSDE)
sde_sample_freq: resample gSDE noise every n steps (-1 = only at rollout start)
rollout_buffer_class: custom rollout buffer class
rollout_buffer_kwargs: kwargs for rollout buffer
normalize_advantage: normalize advantages
stats_window_size: episodes window for logging averages
tensorboard_log: tensorboard log dir
policy_kwargs: kwargs for policy network/architecture
verbose: verbosity level
seed: random seed
device: cpu, cuda, or auto
_init_setup_model: build the network immediately
References#
Mnih et al. (2016), Asynchronous Methods for Deep Reinforcement Learning (A3C)
Schulman et al. (2016), High-Dimensional Continuous Control Using Generalized Advantage Estimation
Stable-Baselines3 A2C docs: https://stable-baselines3.readthedocs.io/en/master/modules/a2c.html